package work.javac.common.database;

import org.postgresql.copy.CopyIn;
import org.postgresql.copy.CopyManager;
import org.postgresql.copy.CopyOut;
import org.postgresql.core.BaseConnection;
import work.javac.bean.Column;
import work.javac.bean.Index;
import work.javac.bean.Partition;
import work.javac.bean.Table;
import work.javac.common.ConfigUtil;
import work.javac.common.Log;
import work.javac.common.database.utils.CopyUtil;
import work.javac.common.database.utils.DBUtils;

import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.StringReader;
import java.sql.*;
import java.util.*;
import java.util.stream.Collectors;

public abstract class SQLExecuteBase {

    // Canonical database type names. Each is matched case-insensitively against the JDBC
    // product name and is also used as the class-name prefix when loading "<type>Execute".
    private static final String ORACLE = "Oracle";
    private static final String DB2 = "Db2";
    private static final String POSTGRESQL = "PostgreSQL";

    /**
     * Runs a parameterized update statement and reports the number of affected rows.
     *
     * @param conn  connection used to prepare the statement
     * @param sql   statement text, with {@code ?} placeholders for the values in {@code datas}
     * @param datas positional values bound in array order starting at index 1; may be
     *              {@code null} when there is nothing to bind, and {@code null} elements are
     *              bound through {@code setNull}
     * @return the value returned by {@link PreparedStatement#executeUpdate()}
     * @throws SQLException if preparing, binding or executing fails; the message is logged
     *                      before the exception is rethrown
     */
    public static int executeSQL(Connection conn, String sql, Object[] datas) throws SQLException {
        try (PreparedStatement statement = conn.prepareStatement(sql)) {
            int paramCount = (datas == null) ? 0 : datas.length;
            for (int idx = 0; idx < paramCount; idx++) {
                int position = idx + 1;
                Object value = datas[idx];
                if (value != null) {
                    statement.setObject(position, value);
                } else {
                    // java.sql.Types.NULL is the constant 0
                    statement.setNull(position, Types.NULL);
                }
            }
            return statement.executeUpdate();
        } catch (SQLException e) {
            Log.err(e.getMessage());
            throw e;
        }
    }

    /**
     * Classifies the connection by its JDBC database product name.
     *
     * @param conn connection whose metadata is inspected
     * @return {@link #ORACLE}, {@link #DB2} or {@link #POSTGRESQL} (first match in that order),
     *         or {@code null} when none matches or the metadata lookup fails with a
     *         {@link SQLException} (which is logged, not rethrown)
     */
    private static String getDataBaseType(Connection conn) {
        try {
            String productName = conn.getMetaData().getDatabaseProductName().toUpperCase();
            // order matters: the first candidate contained in the product name wins
            for (String candidate : new String[]{ORACLE, DB2, POSTGRESQL}) {
                if (productName.contains(candidate.toUpperCase())) {
                    return candidate;
                }
            }
        } catch (SQLException e) {
            Log.err(e.getMessage());
        }
        return null;
    }

    /**
     * Guard that rejects connections which are not classified as PostgreSQL.
     *
     * <p>Despite the name this is an assertion, not a predicate: it returns normally for a
     * PostgreSQL connection and throws otherwise.
     *
     * @param conn connection to check
     * @throws RuntimeException if {@link #isOtherSQL(Connection)} reports a non-PostgreSQL
     *                          source; the exception carries the same text that is logged
     */
    private static void isPostgreSQL(Connection conn) {
        if (isOtherSQL(conn)) {
            Log.err("对不起,此数据源不支持该功能,该功能只适用于%s语法", POSTGRESQL);
            // carry the reason in the exception too, so callers that do not see the log know why
            throw new RuntimeException(
                    String.format("对不起,此数据源不支持该功能,该功能只适用于%s语法", POSTGRESQL));
        }
    }

    /**
     * Tells whether the connection is anything other than PostgreSQL.
     *
     * @param conn connection to classify
     * @return {@code false} only when {@link #getDataBaseType(Connection)} yields
     *         {@link #POSTGRESQL}; {@code true} for every other result, including the
     *         {@code null} it returns for unrecognized products or metadata failures
     */
    private static boolean isOtherSQL(Connection conn) {
        // constant-first equals: getDataBaseType may return null, which must not raise an NPE
        return !POSTGRESQL.equals(getDataBaseType(conn));
    }

    /**
     * Instantiates the database-specific implementation for the given connection.
     *
     * <p>The class name is built as {@code <this package>.execute.<type>Execute}, where
     * {@code <type>} is the value returned by {@link #getDataBaseType(Connection)}, and the
     * class is created through its no-argument constructor.
     *
     * @param conn connection whose database type selects the implementation
     * @return a new instance, or {@code null} when the class cannot be found or instantiated
     *         (including when its constructor throws); the failure is logged
     */
    public static SQLExecuteBase getDBInstance(Connection conn) {
        String pkg = SQLExecuteBase.class.getPackage().getName();
        String dataBaseType = getDataBaseType(conn);
        String className = pkg + ".execute." + dataBaseType + "Execute";
        try {
            Class<?> sqlExecute = Class.forName(className);
            // Class.newInstance() is deprecated and rethrows undeclared checked exceptions from
            // the constructor; going through the Constructor wraps them instead.
            return (SQLExecuteBase) sqlExecute.getDeclaredConstructor().newInstance();
        } catch (ReflectiveOperationException e) {
            // toString() keeps the exception type; some reflective failures have no message
            Log.err(e.toString());
            return null;
        }
    }

    /**
     * Executes a query and returns the first column of its first row as a string.
     *
     * @param conn connection to run the query on
     * @param sql  query text
     * @return the value of column 1 in the first row; {@code null} if that value is SQL NULL
     * @throws SQLException if the query fails or returns no rows
     */
    public static String executeString(Connection conn, String sql) throws SQLException {
        try (Statement stmt = conn.createStatement();
             ResultSet resultSet = stmt.executeQuery(sql)) {
            // Reading from a cursor that is not on a row is driver-specific; fail explicitly.
            if (!resultSet.next()) {
                throw new SQLException("查询未返回任何结果: " + sql);
            }
            return resultSet.getString(1);
        }
    }

    /**
     * Executes a query and materializes every row as a map keyed by upper-cased column name.
     *
     * <p>Columns whose JDBC type is {@code DATE}, {@code TIME} or {@code TIMESTAMP} are stored
     * as returned by {@link ResultSet#getObject(int)}; every other column is stored as the
     * string returned by {@link ResultSet#getString(int)}. Previously the temporal value was
     * always overwritten by the string form because the two branches were not exclusive.
     *
     * @param conn connection to run the query on
     * @param sql  query text
     * @return one map per row, in result-set order; empty when the query returns no rows
     * @throws SQLException if the query or any column read fails; the message is logged first
     */
    public static List<Map<String, Object>> executeListMap(Connection conn, String sql) throws SQLException {
        List<Map<String, Object>> dataList = new LinkedList<>();
        try (Statement stmt = conn.createStatement();
             ResultSet resultSet = stmt.executeQuery(sql)) {
            ResultSetMetaData metaData = resultSet.getMetaData();
            int columnCount = metaData.getColumnCount();

            // Column names and types do not change between rows, so resolve them once.
            String[] keys = new String[columnCount + 1];
            boolean[] temporal = new boolean[columnCount + 1];
            for (int i = 1; i <= columnCount; i++) {
                keys[i] = metaData.getColumnName(i).toUpperCase();
                int columnType = metaData.getColumnType(i);
                // Types.DATE (91), Types.TIME (92) and Types.TIMESTAMP (93) are contiguous
                temporal[i] = Types.DATE <= columnType && columnType <= Types.TIMESTAMP;
            }

            while (resultSet.next()) {
                Map<String, Object> data = new HashMap<>(columnCount);
                for (int i = 1; i <= columnCount; i++) {
                    if (temporal[i]) {
                        data.put(keys[i], resultSet.getObject(i));
                    } else {
                        data.put(keys[i], resultSet.getString(i));
                    }
                }
                dataList.add(data);
            }
        } catch (SQLException e) {
            Log.err(e.getMessage());
            throw e;
        }
        return dataList;
    }

    /**
     * Copies the rows produced by {@code util.getSQL()} on the source connection into the
     * table {@code util.getTABLENAME()} on a PostgreSQL target, buffering rows in {@code util}
     * and flushing them through {@link #executeCopy(Connection, CopyUtil)}.
     *
     * <p>Only source columns that also exist in the target table are transferred. Names are
     * compared case-insensitively; the source column name is registered with {@code util}
     * unchanged. SQL NULL values are written as an empty string.
     *
     * @param fromConn source connection
     * @param toConn   target connection; must be PostgreSQL
     * @param util     carries the query, the target table name and the row buffer
     * @return number of rows read from the source
     * @throws SQLException     if reading or writing fails; the message is logged first
     * @throws RuntimeException if {@code toConn} is not PostgreSQL
     */
    public static long syncTable(Connection fromConn, Connection toConn, CopyUtil util) throws SQLException {
        isPostgreSQL(toConn);
        long len = 0L;
        // number of buffered rows after which the buffer is flushed to the target
        int commit = Integer.parseInt(ConfigUtil.get("sync.commit"));
        try (Statement stmt = fromConn.createStatement();
             ResultSet rs = stmt.executeQuery(util.getSQL())) {
            ResultSetMetaData metaData = rs.getMetaData();
            int columnCount = metaData.getColumnCount();

            Set<String> toColumns = getColumns(toConn, util.getTABLENAME()).stream()
                    .map(String::toUpperCase)
                    .collect(Collectors.toSet());
            // skipped[i] is true for source column i (1-based) that has no target counterpart
            boolean[] skipped = new boolean[columnCount + 1];

            Log.info("开始同步表:%s", util.getTABLENAME());
            for (int i = 1; i <= columnCount; i++) {
                String columnName = metaData.getColumnName(i);
                // the target names were upper-cased above, so upper-case the probe as well
                if (toColumns.contains(columnName.toUpperCase())) {
                    util.addCOLUMNNAME(columnName);
                } else {
                    skipped[i] = true;
                }
            }
            while (rs.next()) {
                for (int i = 1; i <= columnCount; i++) {
                    if (skipped[i]) {
                        continue;
                    }
                    // read each cell exactly once
                    Object value = rs.getObject(i);
                    if (value == null) {
                        util.addData("");
                    } else {
                        util.addData(value);
                    }
                }
                util.endData();
                len++;
                if (len % commit == 0) {
                    executeCopy(toConn, util);
                    Log.info("表:%s,完成同步数据量:%s", util.getTABLENAME(), len);
                }
            }
            // flush whatever is left in the buffer after the last full batch
            executeCopy(toConn, util);
            Log.info("完成同步表:%s,总数据量:%s", util.getTABLENAME(), len);
        } catch (SQLException e) {
            Log.err(e.getMessage());
            throw e;
        }
        return len;
    }

    /**
     * Transfers data from one connection to another, choosing the strategy by source type.
     *
     * <p>A {@code SET statement_timeout} statement is issued on the target first, using the
     * {@code sync.timeout} configuration value multiplied by {@code 60 * 1000}. A PostgreSQL
     * source gets the same statement and is streamed with
     * {@link #executeCopy2Copy(Connection, Connection, CopyUtil)}; any other source goes
     * through {@link #syncTable(Connection, Connection, CopyUtil)}.
     *
     * @param fromConn source connection
     * @param toConn   target connection
     * @param util     copy description shared with the chosen strategy
     * @return the count reported by the chosen strategy
     * @throws SQLException if setting the timeout or the transfer fails
     */
    public static long executeCopy(Connection fromConn, Connection toConn, CopyUtil util) throws SQLException {
        int timeout = Integer.parseInt(ConfigUtil.get("sync.timeout")) * 60 * 1000;
        String timeoutSql = "SET statement_timeout = " + timeout;
        SQLExecute.executeSQL(toConn, timeoutSql);
        if (!isOtherSQL(fromConn)) {
            SQLExecute.executeSQL(fromConn, timeoutSql);
            return executeCopy2Copy(fromConn, toConn, util);
        }
        return syncTable(fromConn, toConn, util);
    }

    /**
     * Sends the rows buffered in {@code util} to PostgreSQL with {@code COPY ... FROM STDIN}
     * and clears the buffer on success.
     *
     * <p>If the server rejects the data with a message containing {@code "end-of-copy"}, the
     * copy is retried once with every backslash in the buffer doubled and
     * {@code util.setNOESCAPING(false)} applied; that setting stays in effect afterwards.
     *
     * @param conn target connection; must be PostgreSQL
     * @param util supplies the COPY statement and the buffered data
     * @return the row count reported by {@code CopyManager.copyIn}; {@code 0} if an
     *         {@link IOException} occurred on the first attempt (it is logged, not rethrown)
     * @throws SQLException     if the copy fails and cannot be recovered by the retry
     * @throws RuntimeException if {@code conn} is not PostgreSQL
     */
    public static long executeCopy(Connection conn, CopyUtil util) throws SQLException {
        isPostgreSQL(conn);
        long size = 0L;
        CopyManager cp = new CopyManager((BaseConnection) conn);
        try (StringReader reader = new StringReader(util.getData())) {
            size = cp.copyIn(util.getCopySQL(), reader);
            util.clearData();
        } catch (SQLException e) {
            // getMessage() may be null; an NPE here would hide the real failure
            String message = e.getMessage();
            if (message != null && message.contains("end-of-copy")) {
                try (StringReader escaped = new StringReader(util.getData().replace("\\", "\\\\"))) {
                    util.setNOESCAPING(false);
                    size = cp.copyIn(util.getCopySQL(), escaped);
                    util.clearData();
                    return size;
                } catch (IOException ioe) {
                    Log.err(ioe.getMessage());
                }
            }
            Log.err(message);
            throw e;
        } catch (IOException e) {
            Log.err(e.getMessage());
        }
        return size;
    }

    /**
     * Resolves the schema name to use for metadata lookups on the connection.
     *
     * <p>Resolution order: the value from {@code DBUtils.getSchema(conn)} if non-null; otherwise
     * {@link Connection#getSchema()}, with the literal {@code "$user"} replaced by the metadata
     * user name. Drivers that do not implement {@code getSchema()} (raising
     * {@link AbstractMethodError}) fall back to the metadata user name directly.
     *
     * @param conn connection to inspect
     * @return the resolved schema name
     * @throws SQLException if metadata access fails, or if the connection reports no schema
     */
    public static String getSchema(Connection conn) throws SQLException {
        String configured = DBUtils.getSchema(conn);
        if (configured != null) {
            return configured;
        }

        String resolved;
        try {
            resolved = conn.getSchema();
        } catch (AbstractMethodError e) {
            // driver predates Connection.getSchema(); the user name is returned unchecked
            return conn.getMetaData().getUserName();
        }

        if ("$user".equals(resolved)) {
            resolved = conn.getMetaData().getUserName();
        }
        if (resolved == null) {
            throw new SQLException("schema可能为空,请检查配置是否存在问题!");
        }
        return resolved;
    }

    /**
     * Lists the column names of a table as reported by the JDBC metadata.
     *
     * <p>The table name is lower-cased for PostgreSQL connections and upper-cased for all
     * others before the lookup; catalog and schema come from the connection.
     *
     * @param conn      connection whose metadata is queried
     * @param tableName table name; must not be {@code null}
     * @return the {@code COLUMN_NAME} values in the order the driver returns them
     * @throws SQLException if the metadata lookup fails
     */
    public static List<String> getColumns(Connection conn, String tableName) throws SQLException {
        String lookupName;
        if (isOtherSQL(conn)) {
            lookupName = tableName.toUpperCase();
        } else {
            lookupName = tableName.toLowerCase();
        }
        DatabaseMetaData metaData = conn.getMetaData();
        List<String> names = new LinkedList<>();
        try (ResultSet rs = metaData.getColumns(conn.getCatalog(), getSchema(conn), lookupName, null)) {
            while (rs.next()) {
                names.add(rs.getString("COLUMN_NAME"));
            }
            return names;
        }
    }

    /**
     * Returns the column names a query produces, in select-list order.
     *
     * <p>The query is executed to obtain its {@link ResultSetMetaData}, but the statement is
     * capped at one row so a query over a large table is not fetched in full just to read its
     * column names.
     *
     * @param conn connection to run the query on
     * @param sql  query whose result columns are wanted
     * @return the names reported by {@link ResultSetMetaData#getColumnName(int)}
     * @throws SQLException if the query fails
     */
    public static List<String> getSqlColumns(Connection conn, String sql) throws SQLException {
        List<String> cols = new LinkedList<>();
        try (Statement stmt = conn.createStatement()) {
            // only the metadata is needed; do not let the driver pull the whole result set
            stmt.setMaxRows(1);
            try (ResultSet rs = stmt.executeQuery(sql)) {
                ResultSetMetaData metaData = rs.getMetaData();
                int columnCount = metaData.getColumnCount();
                for (int i = 1; i <= columnCount; i++) {
                    cols.add(metaData.getColumnName(i));
                }
            }
        }
        return cols;
    }

    /**
     * Streams data between two PostgreSQL connections by piping a {@code COPY ... TO STDOUT}
     * on the source straight into a {@code COPY ... FROM STDIN} on the target.
     *
     * <p>When {@code util} names a target table, only columns present both in that table and
     * in the source query are copied: they are registered with {@code util}, and the query is
     * rewritten to select exactly those columns when the two sides share at least one column
     * or the query contains {@code "SELECT * FROM "}.
     *
     * @param fromConn source connection; must be PostgreSQL
     * @param toConn   target connection; must be PostgreSQL
     * @param util     supplies the query, target table and COPY statements; its type,
     *                 escaping flag, column list and SQL are modified by this call
     * @return the number of buffers relayed from source to target
     * @throws SQLException     if either COPY fails; the message is logged first
     * @throws RuntimeException if either connection is not PostgreSQL
     */
    public static long executeCopy2Copy(Connection fromConn, Connection toConn, CopyUtil util) throws SQLException {
        isPostgreSQL(toConn);
        isPostgreSQL(fromConn);
        long size = 0L;
        // progress is logged every `commit` relayed buffers
        int commit = Integer.parseInt(ConfigUtil.get("sync.commit"));
        CopyManager copyManagerOut = new CopyManager((BaseConnection) fromConn);
        CopyManager copyManagerIn = new CopyManager((BaseConnection) toConn);
        if (util.getTABLENAME() != null) {
            List<String> toColumns = getColumns(toConn, util.getTABLENAME());
            String sql = util.getSQL();
            List<String> fromColumns = getSqlColumns(fromConn, sql);
            List<String> fromColumnsR = new ArrayList<>(fromColumns);
            fromColumnsR.retainAll(toColumns);
            toColumns.retainAll(fromColumns);
            toColumns.forEach(util::addCOLUMNNAME);
            if (fromColumnsR.size() > 0 || sql.contains("SELECT * FROM ")) {
                // NOTE(review): assumes the query contains " FROM " delimited by spaces;
                // otherwise indexOf yields -1 and substring throws. Confirm against callers.
                util.setSQL("SELECT " + util.getCOLUMNNAME() + sql.substring(sql.toUpperCase().indexOf(" FROM ")));
            }
        }
        util.setNOESCAPING(false);
        CopyOut copyOut = null;
        CopyIn copyIn = null;
        try {
            // Both operations are opened inside the try so that a failure to start the
            // inbound copy still cancels the already-active outbound one.
            util.setTYPE("O");
            copyOut = copyManagerOut.copyOut(util.getCopySQL());
            util.setTYPE("I");
            copyIn = copyManagerIn.copyIn(util.getCopySQL());
            Log.info("开始同步表:%s", util.getTABLENAME());
            byte[] buf;
            while ((buf = copyOut.readFromCopy()) != null) {
                copyIn.writeToCopy(buf, 0, buf.length);
                size++;
                if (size % commit == 0) {
                    Log.info("表:%s,预计同步数据量:%s", util.getTABLENAME(), size);
                }
            }
            copyIn.endCopy();
            Log.info("完成同步表:%s,读取数据量:%s,写入数据量:%s", util.getTABLENAME(), copyOut.getHandledRowCount(), size);
        } catch (SQLException e) {
            Log.err(e.getMessage());
            throw e;
        } finally {
            // An operation is still active here only after a failure. Cancel each side
            // independently and only log cancel errors, so the original exception survives
            // and a failing first cancel cannot skip the second.
            if (copyOut != null && copyOut.isActive()) {
                try {
                    copyOut.cancelCopy();
                } catch (SQLException cancelError) {
                    Log.err(cancelError.getMessage());
                }
            }
            if (copyIn != null && copyIn.isActive()) {
                try {
                    copyIn.cancelCopy();
                } catch (SQLException cancelError) {
                    Log.err(cancelError.getMessage());
                }
            }
        }
        return size;
    }

    /**
     * Loads the file at {@code util.getFILEPATH()} into PostgreSQL using the COPY statement
     * from {@code util.getCopySQL()}.
     *
     * <p>Before copying, a {@code SET statement_timeout} statement is issued using the
     * {@code sync.timeout} configuration value multiplied by {@code 60 * 1000}.
     *
     * @param conn target connection; must be PostgreSQL
     * @param util supplies the file path and the COPY statement
     * @return the row count reported by {@code CopyManager.copyIn}
     * @throws SQLException     if the database rejects the copy; logged before rethrow
     * @throws IOException      if the file cannot be read; logged before rethrow
     * @throws RuntimeException if {@code conn} is not PostgreSQL
     */
    public static long executeCopyFileIn(Connection conn, CopyUtil util) throws SQLException, IOException {
        isPostgreSQL(conn);
        int timeout = Integer.parseInt(ConfigUtil.get("sync.timeout")) * 60 * 1000;
        SQLExecute.executeSQL(conn, "SET statement_timeout = " + timeout);
        try (FileInputStream source = new FileInputStream(util.getFILEPATH())) {
            return new CopyManager((BaseConnection) conn).copyIn(util.getCopySQL(), source);
        } catch (SQLException dbError) {
            Log.err("从数据写入表错误");
            Log.err(dbError.getMessage());
            throw dbError;
        } catch (IOException fileError) {
            Log.err("文件读取错误");
            Log.err(fileError.getMessage());
            throw fileError;
        }
    }

    /**
     * Dumps data from PostgreSQL into the file at {@code util.getFILEPATH()} using the COPY
     * statement from {@code util.getCopySQL()}.
     *
     * <p>Before copying, a {@code SET statement_timeout} statement is issued using the
     * {@code sync.timeout} configuration value multiplied by {@code 60 * 1000}.
     *
     * @param conn source connection; must be PostgreSQL
     * @param util supplies the file path and the COPY statement
     * @return the row count reported by {@code CopyManager.copyOut}
     * @throws SQLException     if the database rejects the copy; logged before rethrow
     * @throws IOException      if the file cannot be written; logged before rethrow
     * @throws RuntimeException if {@code conn} is not PostgreSQL
     */
    public static long executeCopyFileOut(Connection conn, CopyUtil util) throws SQLException, IOException {
        isPostgreSQL(conn);
        int timeout = Integer.parseInt(ConfigUtil.get("sync.timeout")) * 60 * 1000;
        SQLExecute.executeSQL(conn, "SET statement_timeout = " + timeout);
        try (FileOutputStream target = new FileOutputStream(util.getFILEPATH())) {
            return new CopyManager((BaseConnection) conn).copyOut(util.getCopySQL(), target);
        } catch (SQLException dbError) {
            Log.err("数据从表写出错误");
            Log.err(dbError.getMessage());
            throw dbError;
        } catch (IOException fileError) {
            Log.err("文件写出错误");
            Log.err(fileError.getMessage());
            throw fileError;
        }
    }

    /**
     * Lists the names of tables (metadata type {@code "TABLE"}) in the connection's catalog
     * and resolved schema that match the given name pattern.
     *
     * <p>A non-null pattern is lower-cased for PostgreSQL connections and upper-cased for all
     * others; {@code null} is passed through to the metadata lookup unchanged. An error is
     * logged when nothing matches, but the (empty) list is still returned.
     *
     * @param conn  connection whose metadata is queried
     * @param table table name pattern, or {@code null}
     * @return matching {@code TABLE_NAME} values in driver order; possibly empty
     * @throws SQLException if schema resolution or the metadata lookup fails
     */
    public static List<String> getTable(Connection conn, String table) throws SQLException {
        String pattern = table;
        if (pattern != null) {
            pattern = isOtherSQL(conn) ? pattern.toUpperCase() : pattern.toLowerCase();
        }
        String schema = getSchema(conn);
        List<String> names = new LinkedList<>();
        try (ResultSet rs = conn.getMetaData().getTables(conn.getCatalog(), schema, pattern, new String[]{"TABLE"})) {
            while (rs.next()) {
                names.add(rs.getString("TABLE_NAME"));
            }
            if (names.isEmpty()) {
                Log.err(String.format("当前schema:%s可能不存在表%s", schema, pattern));
            }
        }
        return names;
    }

    /**
     * Builds a {@link Table} description (columns, primary key, indexes, partitions) for every
     * table of metadata type {@code "TABLE"} matching the given name pattern.
     *
     * <p>A non-null pattern is lower-cased for PostgreSQL connections and upper-cased for all
     * others. Table, column, type, primary-key and index names are stored upper-cased. Index
     * rows without an index name, without a column name, or whose name equals the primary-key
     * name are skipped. A primary key whose {@code PK_NAME} is reported as {@code null} still
     * contributes its columns; only the name is left unset.
     *
     * @param conn connection whose metadata is queried
     * @param tab  table name pattern, or {@code null}
     * @return one {@link Table} per matching table; empty (with an error logged) if none match
     * @throws SQLException if any metadata lookup fails
     */
    public static List<Table> getTables(Connection conn, String tab) throws SQLException {
        if (tab != null) {
            tab = isOtherSQL(conn) ? tab.toUpperCase() : tab.toLowerCase();
        }
        DatabaseMetaData databaseMetaData = conn.getMetaData();
        String catalog = conn.getCatalog();
        String schema = getSchema(conn);
        Map<String, List<Partition>> partitions = SQLExecute.getPartitions(conn);
        List<Table> tableList = new LinkedList<>();
        try (ResultSet tables = databaseMetaData.getTables(catalog, schema, tab, new String[]{"TABLE"})) {
            while (tables.next()) {
                Table table = new Table();
                String tableName = tables.getString("TABLE_NAME");
                table.setTableName(tableName.toUpperCase());
                table.setRemarks(tables.getString("REMARKS"));
                try (ResultSet columns = databaseMetaData.getColumns(catalog, schema, tableName, null);
                     ResultSet primaryKeys = databaseMetaData.getPrimaryKeys(catalog, schema, tableName);
                     ResultSet indexInfo = databaseMetaData.getIndexInfo(catalog, schema, tableName, false, true)) {
                    while (columns.next()) {
                        Column column = new Column();
                        column.setColumnDef(columns.getString("COLUMN_DEF"));
                        column.setColumnName(columns.getString("COLUMN_NAME").toUpperCase());
                        column.setTypeName(columns.getString("TYPE_NAME").toUpperCase());
                        column.setColumnSize(columns.getInt("COLUMN_SIZE"));
                        column.setDecimalDigits(columns.getInt("DECIMAL_DIGITS"));
                        column.setNullable(columns.getInt("NULLABLE"));
                        column.setRemarks(columns.getString("REMARKS"));
                        table.addColumn(column);
                    }
                    while (primaryKeys.next()) {
                        // JDBC documents PK_NAME as "may be null"; do not dereference blindly
                        String pkName = primaryKeys.getString("PK_NAME");
                        if (pkName != null) {
                            table.setPkName(pkName.toUpperCase());
                        }
                        table.addPrimaryKey(primaryKeys.getString("COLUMN_NAME").toUpperCase());
                    }
                    while (indexInfo.next()) {
                        String indexName = indexInfo.getString("INDEX_NAME");
                        String columnName = indexInfo.getString("COLUMN_NAME");
                        // statistic rows carry no names; the primary-key index is already covered
                        if (indexName == null || columnName == null
                                || indexName.equalsIgnoreCase(table.getPkName())) {
                            continue;
                        }
                        Index index = new Index();
                        index.setIndexName(indexName.toUpperCase());
                        index.addColumn(columnName.toUpperCase());
                        table.addIndex(index);
                    }
                    // partitions are keyed by the table name exactly as the driver reported it
                    table.setPartitions(partitions.get(tableName));
                    tableList.add(table);
                    Log.info("%s表; 字段数:%s; 索引数:%s; 分区数:%s", table.getTableName(), table.getColumns().size(), table.getIndexs().size(), table.getPartitions() == null ? 0 : table.getPartitions().size());
                }
            }
        }
        if (tableList.size() == 0) {
            Log.err(String.format("当前schema:%s可能不存在表%s", schema, tab));
        }
        return tableList;
    }

    /**
     * Fetches the definition of every table matching the pattern by running
     * {@code SELECT PG_GET_TABLEDEF('<table>')} for each of them.
     *
     * <p>A table whose query fails is logged and skipped; the remaining tables are still
     * processed.
     *
     * @param conn connection to query; must be PostgreSQL
     * @param tab  table name pattern passed to {@link #getTable(Connection, String)}
     * @return the query results, in the order the tables were listed
     * @throws SQLException     if listing the tables fails
     * @throws RuntimeException if {@code conn} is not PostgreSQL
     */
    public static List<String> getTableDef(Connection conn, String tab) throws SQLException {
        isPostgreSQL(conn);
        List<String> definitions = new LinkedList<>();
        for (String tableName : getTable(conn, tab)) {
            try {
                Log.infof("获取%s表结构", tableName);
                String definition = executeString(conn, String.format("SELECT PG_GET_TABLEDEF('%s')", tableName));
                Log.infon(" 完成!");
                definitions.add(definition);
            } catch (SQLException e) {
                // best effort: one failing table must not abort the rest
                Log.err(e.getMessage());
            }
        }
        return definitions;
    }


    /**
     * Reports, for every table matching the pattern, the ratio between the largest and the
     * smallest per-{@code XC_NODE_ID} row count, rounded up by the query.
     *
     * <p>A ratio above 5 is flagged as possible skew; a {@code null} result is reported as a
     * table that may hold no data. A table whose query fails is logged and skipped.
     *
     * @param conn connection to query; must be PostgreSQL
     * @param tab  table name pattern passed to {@link #getTable(Connection, String)}
     * @return one human-readable line per successfully analysed table
     * @throws SQLException     if listing the tables fails
     * @throws RuntimeException if {@code conn} is not PostgreSQL
     */
    public static List<String> getSkewness(Connection conn, String tab) throws SQLException {
        isPostgreSQL(conn);
        List<String> table = getTable(conn, tab);
        List<String> datas = new LinkedList<>();
        String sql = "SELECT CEIL(MAX(CT) / MIN(CT)) X FROM (SELECT COUNT(1) CT FROM %s GROUP BY XC_NODE_ID)";
        table.forEach(item -> {
            try {
                Log.infof("开始对%s表进行数据倾斜分析!", item);
                String skewness = executeString(conn, String.format(sql, item));
                if (skewness == null) {
                    datas.add(String.format("表:%s; 可能不存在数据", item));
                // CEIL may be rendered as a decimal ("3.0") or exceed the int range; parsing as a
                // double accepts those forms where Integer.parseInt would throw and abort the loop
                } else if (Double.parseDouble(skewness) > 5) {
                    datas.add(String.format("表:%s; 倾斜率:%s; 可能存在数据倾斜!", item, skewness));
                } else {
                    datas.add(String.format("表:%s; 倾斜率:%s", item, skewness));
                }
                Log.infon(" 倾斜率:%s", skewness);
            } catch (SQLException e) {
                // best effort: one failing table must not abort the rest
                Log.err(e.getMessage());
            }
        });
        return datas;
    }

    /**
     * Returns view information for the given connection as a string; the exact content and
     * format are defined by each database-specific subclass.
     *
     * @param conn connection to inspect
     * @return implementation-defined view information
     * @throws SQLException if the underlying lookup fails
     */
    public abstract String getViews(Connection conn) throws SQLException;

    /**
     * Returns partition information for the given connection, grouped in a map.
     *
     * @param conn connection to inspect
     * @return partitions grouped by a string key — presumably the table name, as
     *         {@code getTables} looks its own map up by table name; confirm in subclasses
     * @throws SQLException if the underlying lookup fails
     */
    public abstract Map<String, List<Partition>> getPartitions(Connection conn) throws SQLException;

}
